package org.feichao.xdao.util;

import org.apache.commons.lang.StringUtils;
import org.feichao.xdao.sql.ParamItem;
import org.feichao.xdao.sql.SqlExpression;

import javax.persistence.Column;
import java.lang.reflect.Field;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.util.*;

/**
 * SQL runner: executes the SQL template held by an {@link SQLContext} either as a query,
 * a single update, or a JDBC batch update.
 *
 * @author chao
 * @version 2015-05-25
 */
public class SQLRunner {

	// Invocation state read by this runner: sqlTemplate, params, isBatch, dataSourceName and cnd.
	SQLContext context;

	public SQLRunner(SQLContext context) {
		this.context = context;
	}

	/**
	 * Runs the SQL.
	 *
	 * <p>Templates that start with {@code SELECT} or {@code SHOW} (case-insensitive, leading
	 * whitespace ignored) are run as queries. Everything else is run as an update, or as a JDBC
	 * batch update when {@code context.isBatch} is set.
	 *
	 * @param returnType the expected return type
	 * @param genericReturnType if the return type is a List, the real element type (the generic type argument)
	 * @return for queries, the mapped result (see {@link #executeQuery}); for updates, the value
	 *         returned by {@code JDBCUtil#executeUpdate} (presumably the affected-row count); for
	 *         batch updates, the {@code int[]} returned by {@link PreparedStatement#executeBatch()}
	 * @throws IllegalArgumentException in batch mode, when the first parameter is missing or is not a List
	 * @throws Exception when the query fails
	 */
	public Object execute(Class returnType, Class genericReturnType) throws Exception {
		if (isQuery(context.sqlTemplate)) {
			return executeQuery(returnType, genericReturnType);
		}
		if (!context.isBatch) {
			return executeUpdate();
		}
		Map.Entry<String, Object> batchParam = findBatchParam();
		return executeBatchUpdate(batchParam.getKey(), (List) batchParam.getValue());
	}

	/**
	 * Returns true when the template should be run as a query ({@code SELECT} / {@code SHOW}).
	 * Leading whitespace is ignored so that formatted templates (e.g. starting with a newline)
	 * are still recognised. A null template is not a query.
	 */
	private static boolean isQuery(String sqlTemplate) {
		String head = sqlTemplate == null ? null : sqlTemplate.trim();
		return StringUtils.startsWithIgnoreCase(head, "SELECT")
				|| StringUtils.startsWithIgnoreCase(head, "SHOW");
	}

	/**
	 * Picks the parameter that drives a batch update: the first entry of {@code context.params}.
	 * Its value must be a List; each element becomes one batch item.
	 * NOTE(review): if params is a HashMap with several entries, "first" is iteration order,
	 * not declaration order — confirm callers pass a single (or ordered) parameter map.
	 *
	 * @throws IllegalArgumentException if there is no parameter or its value is not a List
	 */
	private Map.Entry<String, Object> findBatchParam() {
		if (context.params == null || context.params.isEmpty()) {
			throw new IllegalArgumentException(
					"Batch update requires a List parameter, but no parameter was supplied");
		}
		Map.Entry<String, Object> first = context.params.entrySet().iterator().next();
		if (!(first.getValue() instanceof List)) {
			throw new IllegalArgumentException(
					"Batch update parameter '" + first.getKey() + "' must be a List");
		}
		return first;
	}

	/**
	 * Runs the template as a query and maps the result set according to {@code returnType}.
	 * The JDBC resources are always released, including when the query itself fails.
	 */
	private Object executeQuery(Class returnType, Class genericReturnType) throws Exception {
		SQLTemplateParser parser = new SQLTemplateParser(context);
		SQLTemplateParser.Result result = parser.parse();
		List<Object> params = buildQueryParams(result);

		JDBCUtil dbUtil = new JDBCUtil(context.dataSourceName);
		try {
			ResultSet rs;
			try {
				rs = dbUtil.executeQuery(result.sql, params.toArray());
			} catch (Exception e) {
				throw new RuntimeException(
						String.format("Error in execute sql:%s", result.sql),
						e);
			}
			return readResult(rs, returnType, genericReturnType, result.sql);
		} finally {
			dbUtil.close();
		}
	}

	/**
	 * Collects positional parameter values for a query, in two steps.
	 * First come the template parameters, in the order the parser expects them. Outside batch
	 * mode, a List value is expanded into its elements (presumably for {@code IN (...)} clauses).
	 * Then come the parameters of every {@link ParamItem} in {@code context.cnd}, if present.
	 */
	private List<Object> buildQueryParams(SQLTemplateParser.Result result) {
		List<Object> params = new ArrayList<Object>();
		for (String paramName : result.paramOrder) {
			Object paramObj = context.params.get(paramName);
			if (paramObj instanceof List && !context.isBatch) {
				params.addAll((List) paramObj);
			} else {
				params.add(paramObj);
			}
		}

		if (context.cnd != null) {
			for (SqlExpression e : context.cnd.expressionList) {
				if (e instanceof ParamItem) {
					params.add(((ParamItem) e).getParam());
				}
			}
		}
		return params;
	}

	/**
	 * Maps an open result set to the requested return type. There are four cases.
	 * If {@code returnType} is List or a supertype of it, every row is parsed into a
	 * {@code genericReturnType} and added to a LinkedList.
	 * If it is long/Long or int/Integer, column 1 of the first row is returned; when there is no
	 * row, a warning is logged and 0 is returned.
	 * Otherwise, the first row is parsed into {@code returnType}, or null is returned when there
	 * is no row.
	 */
	private static Object readResult(ResultSet rs, Class returnType, Class genericReturnType, String sql)
			throws Exception {
		if (returnType.isAssignableFrom(List.class)) {
			List list = new LinkedList();
			try {
				while (rs.next()) {
					list.add(JDBCUtil.parseResultSet(rs, genericReturnType));
				}
			} catch (Exception e) {
				// Best effort: rows parsed before the failure are still returned.
				Log.error(e, "Error in parsing result set");
			}
			return list;
		}
		if (returnType.equals(long.class) || returnType.equals(Long.class)) {
			if (rs.next()) {
				return rs.getLong(1);
			}
			Log.warn("SQL statement[%s] seems not right.", sql);
			return 0L;
		}
		if (returnType.equals(int.class) || returnType.equals(Integer.class)) {
			if (rs.next()) {
				return rs.getInt(1);
			}
			Log.warn("SQL statement[%s] seems not right.", sql);
			return 0;
		}
		if (rs.next()) {
			return JDBCUtil.parseResultSet(rs, returnType);
		}
		return null;
	}

	/**
	 * Runs the template as a single update and returns what {@code JDBCUtil#executeUpdate}
	 * returns. The JDBC resources are released in every case (previously they were never closed).
	 */
	private Integer executeUpdate() {
		SQLTemplateParser parser = new SQLTemplateParser(context);
		SQLTemplateParser.Result result = parser.parse();
		Object[] params = orderParams(context.params, result.paramOrder);

		JDBCUtil dbUtil = new JDBCUtil(context.dataSourceName);
		try {
			return dbUtil.executeUpdate(result.sql, params);
		} catch (Exception e) {
			throw new RuntimeException(
					String.format("Error in execute sql:%s", result.sql),
					e);
		} finally {
			dbUtil.close();
		}
	}

	/**
	 * Runs the template once per element of {@code data} as a single JDBC batch.
	 * For each element, parameters are resolved against a scope that contains only
	 * {@code paramName -> element}.
	 *
	 * @return the per-item update counts from {@link PreparedStatement#executeBatch()}
	 */
	private int[] executeBatchUpdate(String paramName, List data) {
		SQLTemplateParser parser = new SQLTemplateParser(context);
		SQLTemplateParser.Result result = parser.parse();
		JDBCUtil dbUtil = new JDBCUtil(context.dataSourceName);

		try {
			PreparedStatement ps = dbUtil.createPreparedStatement(result.sql);
			// The scope is reused for every item; only the batch parameter's value changes.
			Map<String, Object> itemScope = new HashMap<String, Object>();
			for (Object item : data) {
				itemScope.put(paramName, item);
				dbUtil.setPreparedStatementParams(ps, orderParams(itemScope, result.paramOrder));
				ps.addBatch();
			}
			return ps.executeBatch();
		} catch (Exception e) {
			throw new RuntimeException(
					String.format("Error in execute sql:%s", result.sql),
					e);
		} finally {
			dbUtil.close();
		}
	}

	/**
	 * Arranges parameter values in the order the parsed SQL expects them. Each name is
	 * resolved against {@code source} via {@code ExpressionUtil.get}, which presumably also
	 * supports property paths such as {@code user.name}.
	 */
	private static Object[] orderParams(Map<String, Object> source, Collection<String> order) {
		Object[] params = new Object[order.size()];
		int index = 0;
		for (String name : order) {
			params[index++] = ExpressionUtil.get(source, name);
		}
		return params;
	}

}
